{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "48518923",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/retrievers/ensemble_retrieval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bf1de44-4047-46cf-a04c-dbf910d9e179",
   "metadata": {},
   "source": [
    "# Ensemble Retrieval Guide\n",
    "\n",
    "Oftentimes when building a RAG applications there are many retreival parameters/strategies to decide from (from chunk size to vector vs. keyword vs. hybrid search, for instance).\n",
    "\n",
    "Thought: what if we could try a bunch of strategies at once, and have any AI/reranker/LLM prune the results?\n",
    "\n",
    "This achieves two purposes:\n",
    "- Better (albeit more costly) retrieved results by pooling results from multiple strategies, assuming the reranker is good\n",
    "- A way to benchmark different retrieval strategies against each other (w.r.t reranker)\n",
    "\n",
    "This guide showcases this over the Llama 2 paper. We do ensemble retrieval over different chunk sizes and also different indices.\n",
    "\n",
    "**NOTE**: A closely related guide is our [Ensemble Query Engine Guide](https://gpt-index.readthedocs.io/en/stable/examples/query_engine/ensemble_qury_engine.html) - make sure to check it out! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2b0a5a-6449-4485-8217-b252fa47720e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e73fead-ec2c-4346-bd08-e183c13c7e29",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Here we define the necessary imports."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0212306c",
   "metadata": {},
   "source": [
    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8dd9c35",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2d59778-4cda-47b5-8cd0-b80fee91d1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: This is ONLY necessary in jupyter notebook.\n",
    "# Details: Jupyter runs an event-loop behind the scenes.\n",
    "#          This results in nested event-loops when we start an event-loop to make async queries.\n",
    "#          This is normally not allowed, we use nest_asyncio to allow it for convenience.\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c628448c-573c-4eeb-a7e1-707fe8cc575c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: NumExpr detected 12 cores but \"NUMEXPR_MAX_THREADS\" not set, so enforcing safe limit of 8.\n",
      "NumExpr defaulting to 8 threads.\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "logging.getLogger().handlers = []\n",
    "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n",
    "\n",
    "from llama_index import (\n",
    "    VectorStoreIndex,\n",
    "    SummaryIndex,\n",
    "    SimpleDirectoryReader,\n",
    "    ServiceContext,\n",
    "    StorageContext,\n",
    ")\n",
    "from llama_index.response.notebook_utils import display_response\n",
    "from llama_index.llms import OpenAI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "787174ed-10ce-47d7-82fd-9ca7f891eea7",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "\n",
    "In this section we first load in the Llama 2 paper as a single document. We then chunk it multiple times, according to different chunk sizes. We build a separate vector index corresponding to each chunk size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73a003b8-1d4a-4faf-9402-46d5977bb28d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-09-28 12:56:38--  https://arxiv.org/pdf/2307.09288.pdf\n",
      "Resolving arxiv.org (arxiv.org)... 128.84.21.199\n",
      "Connecting to arxiv.org (arxiv.org)|128.84.21.199|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 13661300 (13M) [application/pdf]\n",
      "Saving to: ‘data/llama2.pdf’\n",
      "\n",
      "data/llama2.pdf     100%[===================>]  13.03M   521KB/s    in 42s     \n",
      "\n",
      "2023-09-28 12:57:20 (320 KB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dada361-4ac5-44a9-a29c-ae1aa8f5af78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_index import Document\n",
    "from llama_hub.file.pymu_pdf.base import PyMuPDFReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ed22d71-2dfb-4c77-9511-57166a3de6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyMuPDFReader()\n",
    "docs0 = loader.load(file_path=Path(\"./data/llama2.pdf\"))\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "docs = [Document(text=doc_text)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e63e6b28-22ae-4af7-9a1d-b2dcd7fafa8f",
   "metadata": {},
   "source": [
    "Here we try out different chunk sizes: 128, 256, 512, and 1024."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7081194a-ede7-478e-bff2-23e89e23ef16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Chunk Size: 128\n",
      "Chunk Size: 256\n",
      "Chunk Size: 512\n",
      "Chunk Size: 1024\n"
     ]
    }
   ],
   "source": [
    "# initialize service context (set chunk size)\n",
    "llm = OpenAI(model=\"gpt-4\")\n",
    "chunk_sizes = [128, 256, 512, 1024]\n",
    "service_contexts = []\n",
    "nodes_list = []\n",
    "vector_indices = []\n",
    "query_engines = []\n",
    "for chunk_size in chunk_sizes:\n",
    "    print(f\"Chunk Size: {chunk_size}\")\n",
    "    service_context = ServiceContext.from_defaults(\n",
    "        chunk_size=chunk_size, llm=llm\n",
    "    )\n",
    "    service_contexts.append(service_context)\n",
    "    nodes = service_context.node_parser.get_nodes_from_documents(docs)\n",
    "\n",
    "    # add chunk size to nodes to track later\n",
    "    for node in nodes:\n",
    "        node.metadata[\"chunk_size\"] = chunk_size\n",
    "        node.excluded_embed_metadata_keys = [\"chunk_size\"]\n",
    "        node.excluded_llm_metadata_keys = [\"chunk_size\"]\n",
    "\n",
    "    nodes_list.append(nodes)\n",
    "\n",
    "    # build vector index\n",
    "    vector_index = VectorStoreIndex(nodes)\n",
    "    vector_indices.append(vector_index)\n",
    "\n",
    "    # query engines\n",
    "    query_engines.append(vector_index.as_query_engine())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50b9ac92-0ab2-4306-88ca-48f7e06ee763",
   "metadata": {},
   "source": [
    "## Define Ensemble Retriever\n",
    "\n",
    "We setup an \"ensemble\" retriever primarily using our recursive retrieval abstraction. This works like the following:\n",
    "- Define a separate `IndexNode` corresponding to the vector retriever for each chunk size (retriever for chunk size 128, retriever for chunk size 256, and more)\n",
    "- Put all IndexNodes into a single `SummaryIndex` - when the corresponding retriever is called, *all* nodes are returned.\n",
    "- Define a Recursive Retriever, with the root node being the summary index retriever. This will first fetch all nodes from the summary index retriever, and then recursively call the vector retriever for each chunk size.\n",
    "- Rerank the final results.\n",
    "\n",
    "The end result is that all vector retrievers are called when a query is run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbca69b4-d8d5-4dcb-af33-f9ed4a91ec05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# try ensemble retrieval\n",
    "\n",
    "from llama_index.tools import RetrieverTool\n",
    "from llama_index.schema import IndexNode\n",
    "\n",
    "# retriever_tools = []\n",
    "retriever_dict = {}\n",
    "retriever_nodes = []\n",
    "for chunk_size, vector_index in zip(chunk_sizes, vector_indices):\n",
    "    node_id = f\"chunk_{chunk_size}\"\n",
    "    node = IndexNode(\n",
    "        text=(\n",
    "            \"Retrieves relevant context from the Llama 2 paper (chunk size\"\n",
    "            f\" {chunk_size})\"\n",
    "        ),\n",
    "        index_id=node_id,\n",
    "    )\n",
    "    retriever_nodes.append(node)\n",
    "    retriever_dict[node_id] = vector_index.as_retriever()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "394c1883-4a85-4908-872f-d32072bfe63a",
   "metadata": {},
   "source": [
    "Define recursive retriever."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9eaa6f-8f11-4380-b3c6-79092f17def3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.selectors.pydantic_selectors import PydanticMultiSelector\n",
    "\n",
    "# from llama_index.retrievers import RouterRetriever\n",
    "from llama_index.retrievers import RecursiveRetriever\n",
    "from llama_index import SummaryIndex\n",
    "\n",
    "# the derived retriever will just retrieve all nodes\n",
    "summary_index = SummaryIndex(retriever_nodes)\n",
    "\n",
    "retriever = RecursiveRetriever(\n",
    "    root_id=\"root\",\n",
    "    retriever_dict={\"root\": summary_index.as_retriever(), **retriever_dict},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b8317e0-2249-4d83-9ff0-dbd6701e25ec",
   "metadata": {},
   "source": [
    "Let's test the retriever on a sample query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c72c61c-d4f7-4159-bb80-1989468ab61c",
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = await retriever.aretrieve(\n",
    "    \"Tell me about the main aspects of safety fine-tuning\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590ed8bc-83ad-4851-9ec6-bfbbdf3ff38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Number of nodes: {len(nodes)}\")\n",
    "for node in nodes:\n",
    "    print(node.node.metadata[\"chunk_size\"])\n",
    "    print(node.node.get_text())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7e39a9-b031-46bb-b835-e09ed7bec3ee",
   "metadata": {},
   "source": [
    "Define reranker to process the final retrieved set of nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f26c527-17d2-4d4e-a6ee-8ea878ef8742",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define reranker\n",
    "from llama_index.postprocessor import (\n",
    "    LLMRerank,\n",
    "    SentenceTransformerRerank,\n",
    "    CohereRerank,\n",
    ")\n",
    "\n",
    "# reranker = LLMRerank()\n",
    "# reranker = SentenceTransformerRerank(top_n=10)\n",
    "reranker = CohereRerank(top_n=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fcbee6b-aa93-4144-95b7-f5542c9c689e",
   "metadata": {},
   "source": [
    "Define retriever query engine to integrate the recursive retriever + reranker together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "828589ef-d062-40dc-8a4b-245190769445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define RetrieverQueryEngine\n",
    "from llama_index.query_engine import RetrieverQueryEngine\n",
    "\n",
    "query_engine = RetrieverQueryEngine(retriever, node_postprocessors=[reranker])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e3c341-e66d-4950-88d5-6411699d064b",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = query_engine.query(\n",
    "    \"Tell me about the main aspects of safety fine-tuning\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aa680dd-03a0-4a76-b456-c4ef0136fdc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "display_response(\n",
    "    response, show_source=True, source_length=500, show_source_metadata=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7850424f-84fc-4ea3-bb6a-18b4bc1d6dd5",
   "metadata": {},
   "source": [
    "### Analyzing the Relative Importance of each Chunk\n",
    "\n",
    "One interesting property of ensemble-based retrieval is that through reranking, we can actually use the ordering of chunks in the final retrieved set to determine the importance of each chunk size. For instance, if certain chunk sizes are always ranked near the top, then those are probably more relevant to the query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7a8303-be94-45c5-8bc5-13ec8c7f1694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the average precision for each chunk size based on positioning in combined ranking\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def mrr_all(metadata_values, metadata_key, source_nodes):\n",
    "    # source nodes is a ranked list\n",
    "    # go through each value, find out positioning in source_nodes\n",
    "    value_to_mrr_dict = {}\n",
    "    for metadata_value in metadata_values:\n",
    "        mrr = 0\n",
    "        for idx, source_node in enumerate(source_nodes):\n",
    "            if source_node.node.metadata[metadata_key] == metadata_value:\n",
    "                mrr = 1 / (idx + 1)\n",
    "                break\n",
    "            else:\n",
    "                continue\n",
    "\n",
    "        # normalize AP, set in dict\n",
    "        value_to_mrr_dict[metadata_value] = mrr\n",
    "\n",
    "    df = pd.DataFrame(value_to_mrr_dict, index=[\"MRR\"])\n",
    "    df.style.set_caption(\"Mean Reciprocal Rank\")\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adebbb82-764e-4b45-933e-84bf4ad64d40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Reciprocal Rank for each Chunk Size\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>128</th>\n",
       "      <th>256</th>\n",
       "      <th>512</th>\n",
       "      <th>1024</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>MRR</th>\n",
       "      <td>0.333333</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.25</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         128   256   512   1024\n",
       "MRR  0.333333   1.0   0.5  0.25"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute the Mean Reciprocal Rank for each chunk size (higher is better)\n",
    "# we can see that chunk size of 256 has the highest ranked results.\n",
    "print(\"Mean Reciprocal Rank for each Chunk Size\")\n",
    "mrr_all(chunk_sizes, \"chunk_size\", response.source_nodes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b27a2f3c-55ce-4fa6-a15a-be539723a967",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "\n",
    "We more rigorously evaluate how well an ensemble retriever works compared to the \"baseline\" retriever.\n",
    "\n",
    "We define/load an eval benchmark dataset and then run different evaluations over it.\n",
    "\n",
    "**WARNING**: This can be *expensive*, especially with GPT-4. Use caution and tune the sample size to fit your budget."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4d66b14-4f38-4b61-809c-f603d7e09ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import (\n",
    "    DatasetGenerator,\n",
    "    QueryResponseDataset,\n",
    ")\n",
    "from llama_index import ServiceContext\n",
    "from llama_index.llms import OpenAI\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5a0d4e-7c0a-40f2-be5c-9dc1297483fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: run this if the dataset isn't already saved\n",
    "eval_service_context = ServiceContext.from_defaults(llm=OpenAI(model=\"gpt-4\"))\n",
    "# generate questions from the largest chunks (1024)\n",
    "dataset_generator = DatasetGenerator(\n",
    "    nodes_list[-1],\n",
    "    service_context=eval_service_context,\n",
    "    show_progress=True,\n",
    "    num_questions_per_chunk=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10b97355-de34-4840-a68f-4d137ab1b850",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset = await dataset_generator.agenerate_dataset_from_nodes(num=60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab72dc2-0d17-4925-83ca-a0630de28349",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_dataset.save_json(\"data/llama2_eval_qr_dataset.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7fd120e-36c2-4d20-8fca-d3e783756879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional\n",
    "eval_dataset = QueryResponseDataset.from_json(\n",
    "    \"data/llama2_eval_qr_dataset.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f8b3a4-7824-4924-848f-fe8155291f80",
   "metadata": {},
   "source": [
    "### Compare Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4e2e1a-a7cf-471a-b786-2645cfb327c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0cbc18a-9cec-4c29-a5e7-c0ba3752118e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation import (\n",
    "    CorrectnessEvaluator,\n",
    "    SemanticSimilarityEvaluator,\n",
    "    RelevancyEvaluator,\n",
    "    FaithfulnessEvaluator,\n",
    "    PairwiseComparisonEvaluator,\n",
    ")\n",
    "\n",
    "# NOTE: can uncomment other evaluators\n",
    "evaluator_c = CorrectnessEvaluator(service_context=eval_service_context)\n",
    "evaluator_s = SemanticSimilarityEvaluator(service_context=eval_service_context)\n",
    "evaluator_r = RelevancyEvaluator(service_context=eval_service_context)\n",
    "evaluator_f = FaithfulnessEvaluator(service_context=eval_service_context)\n",
    "\n",
    "pairwise_evaluator = PairwiseComparisonEvaluator(\n",
    "    service_context=eval_service_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a96ca32-9196-43e1-b82a-403becc2d56e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.evaluation.eval_utils import get_responses, get_results_df\n",
    "from llama_index.evaluation import BatchEvalRunner\n",
    "\n",
    "max_samples = 60\n",
    "\n",
    "eval_qs = eval_dataset.questions\n",
    "qr_pairs = eval_dataset.qr_pairs\n",
    "ref_response_strs = [r for (_, r) in qr_pairs]\n",
    "\n",
    "# resetup base query engine and ensemble query engine\n",
    "# base query engine\n",
    "base_query_engine = vector_indices[-1].as_query_engine(similarity_top_k=2)\n",
    "# ensemble query engine\n",
    "reranker = CohereRerank(top_n=4)\n",
    "query_engine = RetrieverQueryEngine(retriever, node_postprocessors=[reranker])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bc3a77d-b513-4df8-a0eb-8543d86eb8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_pred_responses = get_responses(\n",
    "    eval_qs[:max_samples], base_query_engine, show_progress=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1401e805-97d3-447a-8460-d23c664bbcb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_responses = get_responses(\n",
    "    eval_qs[:max_samples], query_engine, show_progress=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "023e8638-6d52-4d19-b0d8-99f43bfa5ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "pred_response_strs = [str(p) for p in pred_responses]\n",
    "base_pred_response_strs = [str(p) for p in base_pred_responses]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "078f6459-c38d-4d3f-a53b-436b9d1d86b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluator_dict = {\n",
    "    \"correctness\": evaluator_c,\n",
    "    \"faithfulness\": evaluator_f,\n",
    "    # \"relevancy\": evaluator_r,\n",
    "    \"semantic_similarity\": evaluator_s,\n",
    "}\n",
    "batch_runner = BatchEvalRunner(evaluator_dict, workers=1, show_progress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e88130b1-8938-4ba7-a91c-ee9c0e19a689",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_results = await batch_runner.aevaluate_responses(\n",
    "    queries=eval_qs[:max_samples],\n",
    "    responses=pred_responses[:max_samples],\n",
    "    reference=ref_response_strs[:max_samples],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8becc3de-2530-4e64-bafe-180fbd64a11a",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_eval_results = await batch_runner.aevaluate_responses(\n",
    "    queries=eval_qs[:max_samples],\n",
    "    responses=base_pred_responses[:max_samples],\n",
    "    reference=ref_response_strs[:max_samples],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9153a9f-109c-437e-a860-9e2346859659",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>names</th>\n",
       "      <th>correctness</th>\n",
       "      <th>faithfulness</th>\n",
       "      <th>semantic_similarity</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ensemble Retriever</td>\n",
       "      <td>4.375000</td>\n",
       "      <td>0.983333</td>\n",
       "      <td>0.964546</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Base Retriever</td>\n",
       "      <td>4.066667</td>\n",
       "      <td>0.983333</td>\n",
       "      <td>0.956692</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                names  correctness  faithfulness  semantic_similarity\n",
       "0  Ensemble Retriever     4.375000      0.983333             0.964546\n",
       "1      Base Retriever     4.066667      0.983333             0.956692"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results_df = get_results_df(\n",
    "    [eval_results, base_eval_results],\n",
    "    [\"Ensemble Retriever\", \"Base Retriever\"],\n",
    "    [\"correctness\", \"faithfulness\", \"semantic_similarity\"],\n",
    ")\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33aa47ef-df96-4e58-8960-92d60091d6ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_runner = BatchEvalRunner(\n",
    "    {\"pairwise\": pairwise_evaluator}, workers=3, show_progress=True\n",
    ")\n",
    "\n",
    "pairwise_eval_results = await batch_runner.aevaluate_response_strs(\n",
    "    queries=eval_qs[:max_samples],\n",
    "    response_strs=pred_response_strs[:max_samples],\n",
    "    reference=base_pred_response_strs[:max_samples],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f1c839-ac5e-4458-af05-c6ed4a2db7d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>names</th>\n",
       "      <th>pairwise</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Pairwise Comparison</td>\n",
       "      <td>0.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 names  pairwise\n",
       "0  Pairwise Comparison       0.5"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df = get_results_df(\n",
    "    [eval_results, base_eval_results],\n",
    "    [\"Ensemble Retriever\", \"Base Retriever\"],\n",
    "    [\"pairwise\"],\n",
    ")\n",
    "display(results_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
